{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from einops import rearrange, repeat\n",
    "from flash_attn import flash_attn_with_kvcache, flash_attn_varlen_func\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention_ref(\n",
    "    q,\n",
    "    k,\n",
    "    v,\n",
    "):\n",
    "    \"\"\"\n",
    "    Arguments:\n",
    "        q: (batch_size, seqlen_q, nheads, head_dim)\n",
    "        k: (batch_size, seqlen_k, nheads_k, head_dim)\n",
    "        v: (batch_size, seqlen_k, nheads_k, head_dim)\n",
    "    Output:\n",
    "        output: (batch_size, seqlen_q, nheads, head_dim)\n",
    "        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout\n",
    "    \"\"\"\n",
    "    dtype_og = q.dtype\n",
    "\n",
    "    d = q.shape[-1]\n",
    "\n",
    "    scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(d), k)\n",
    "    \n",
    "    attention = torch.softmax(scores, dim=-1).to(v.dtype)\n",
    "\n",
    "    output = torch.einsum(\"bhts,bshd->bthd\", attention, v)\n",
    "\n",
    "    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1\n",
    "seqlen_q = 1\n",
    "seqlen_kv = 1024\n",
    "num_heads = 32\n",
    "head_dim = 128\n",
    "device = \"cuda\"\n",
    "dtype = torch.float16\n",
    "seqlen_new = seqlen_q\n",
    "\n",
    "# Initialize tensors\n",
    "q = torch.randn(batch_size, seqlen_q, num_heads, head_dim, device=device, dtype=dtype)\n",
    "\n",
    "k_cache = torch.randn(batch_size, seqlen_kv, num_heads, head_dim, device=device, dtype=dtype)\n",
    "v_cache = torch.randn(batch_size, seqlen_kv, num_heads, head_dim, device=device, dtype=dtype)\n",
    "\n",
    "k = 10 * torch.rand(batch_size, seqlen_new, num_heads, head_dim, device=device, dtype=dtype)\n",
    "v = 15 * torch.rand(batch_size, seqlen_new, num_heads, head_dim, device=device, dtype=dtype) - 7.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flashattn vs pytorch: 2.187490463256836e-05\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Compute flash attention with kvcache\n",
    "out_flashattn = flash_attn_with_kvcache(q, k_cache, v_cache)\n",
    "\n",
    "# Reference attention computation\n",
    "out_ref, _ = attention_ref(q, k_cache, v_cache)\n",
    "\n",
    "# Print differences\n",
    "print(f\"flashattn vs pytorch: {(out_flashattn - out_ref).abs().mean().item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flashattn vs pytorch without append KV: 2.187490463256836e-05\n",
      "flashattn vs pytorch with append KV: 0.474365234375\n"
     ]
    }
   ],
   "source": [
    "k_cache_new = torch.cat([k_cache, k], dim=1)\n",
    "v_cache_new = torch.cat([v_cache, v], dim=1)\n",
    "\n",
    "cache_seqlens = torch.tensor([seqlen_kv], dtype=torch.int32, device=device)\n",
    "out_flashattn_new = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens)\n",
    "\n",
    "out_ref, _ = attention_ref(q, k_cache, v_cache)\n",
    "print(f\"flashattn vs pytorch without append KV: {(out_flashattn_new - out_ref).abs().mean().item()}\")   \n",
    "\n",
    "out_ref_new, _ = attention_ref(q, k_cache_new, v_cache_new)\n",
    "print(f\"flashattn vs pytorch with append KV: {(out_flashattn_new - out_ref_new).abs().mean().item()}\")   \n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bitattn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
